{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import rasterio\n",
    "from glob import glob\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(base_dir):\n",
    "\n",
    "    s1_dir=glob(base_dir+'s1_validation'+'/'+'*.tif')\n",
    "    s2_dir=[]\n",
    "    lc_dir=[]\n",
    "    for i in tqdm(range(len(s1_dir))):\n",
    "        s1_filename=os.path.basename(s1_dir[i])\n",
    "        former=s1_filename.split('_s1')[0]\n",
    "        ID=s1_filename.split('_s1')[1]\n",
    "        s2_dir.append(base_dir+'s2_validation/'+former+'_s2'+ID)\n",
    "        lc_dir.append(base_dir + 'lc_validation/' + former + '_lc' + ID)\n",
    "\n",
    "    s1_dir=np.array(s1_dir)\n",
    "    s2_dir=np.array(s2_dir)\n",
    "    lc_dir=np.array(lc_dir)\n",
    "\n",
    "    return s1_dir,s2_dir,lc_dir\n",
    "\n",
    "def read_data(f1,f2):\n",
    "    \n",
    "    with rasterio.open(f1) as patch:\n",
    "        x1 = patch.read(list(range(1,3)))\n",
    "\n",
    "    with rasterio.open(f2) as patch:\n",
    "        x2 = patch.read(list(range(1,14)))\n",
    "\n",
    "    return x1,x2\n",
    "\n",
    "def clean(x1,x2):\n",
    "    # s1\n",
    "    x1[np.isnan(x1)] = 0\n",
    "    # s2\n",
    "    x2[np.isnan(x2)] = 0\n",
    "\n",
    "    # s1_recommend\n",
    "    x1[x1<-25]=-25\n",
    "    x1[x1>0]=0\n",
    "    # s2_recommend\n",
    "    x2[x2 < 0] = 0\n",
    "    x2[x2>10000]=10000\n",
    "\n",
    "    return x1.astype('float16'),x2.astype('float16')\n",
    "\n",
    "def norm(x1,x2):\n",
    "    # input,x1:[-25,0],x2:[0,10000]\n",
    "    h,w,c1=x1.shape\n",
    "    h,w,c2=x2.shape\n",
    "    x2 /= 10000 * 1.0\n",
    "    x1 = (x1 + 25) / 25 * 1.0\n",
    "    return x1,x2\n",
    "\n",
    "lab_dict={1:1,2:1,3:1,4:1,5:1,6:2,7:2,8:3,9:3,10:4,11:5,12:6,14:6,13:7,15:8,16:9,17:10}\n",
    "\n",
    "def process_label(y):\n",
    "    C,H,W=y.shape\n",
    "    y=y[[0],:,:]#1,h,w, simplified IGBP\n",
    "    y=y.reshape(-1)\n",
    "    y = list(map(lambda x: lab_dict[x], y))\n",
    "    y=np.array(y)#list->array\n",
    "    y=y.reshape(-1,H,W)\n",
    "    y-=1 # start from 0\n",
    "    return y\n",
    "\n",
    "def filter_label(x2,y):\n",
    "\n",
    "    x2[np.isnan(x2)] = 0\n",
    "\n",
    "    x2[x2>10000]=10000\n",
    "    x2[x2<0]=0\n",
    "\n",
    "    x2 = x2.astype('float')\n",
    "\n",
    "    R = x2[:, :, 3]\n",
    "    G = x2[:, :, 2]\n",
    "    B = x2[:, :, 1]\n",
    "    Nir = x2[:, :, 7]  # TM4\n",
    "    Mir = x2[:, :, 10]  # TM5\n",
    "    SWir = x2[:, :, 11]  # TM7\n",
    "\n",
    "    MSI = SWir / Nir\n",
    "    NDWI = (G - Nir) / (G + Nir)\n",
    "    NDVI = (Nir - R) / (Nir + R)\n",
    "    NDBBI = (1.5 * SWir - (Nir + G) / 2.) / (1.5 * SWir + (Nir + G) / 2.)  # 归一化差值裸地与建筑用地指数\n",
    "    BSI = ((Mir + R) - (Nir + B)) / ((Mir + R) + (Nir + B))  # 裸土指数\n",
    "    NBI = R * SWir / Nir\n",
    "\n",
    "    y_clean=y.copy()\n",
    "\n",
    "    # Forest0,Shrubland1,bg_1-2,Grassland3,Wetlands4,Croplands5,Urban6,bg2-7,Barren8,water9\n",
    "    \n",
    "    # columns\n",
    "\n",
    "    # 修正不符合要求的森林类\n",
    "    y_clean[np.where((NDVI > 0.75) & (y != 0))] = 10\n",
    "    # 修正不符合要求的灌木类\n",
    "    y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 1.5) & (y != 1))] = 10\n",
    "    # 修正不符合要求的草地类\n",
    "    y_clean[np.where((NDVI > 0.4) & (NDVI < 0.55) & (y != 3))] = 10\n",
    "    # 修正不符合要求的湿地类\n",
    "    y_clean[np.where((NDVI > 0.6) & (NDVI < 0.75) & (y != 4))] = 10\n",
    "    # 修正不符合要求的农田类\n",
    "    #y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 1) & (MSI < 1.5) & (y != 5))] = 10\n",
    "    # 修正不符合要求的建筑类\n",
    "    #y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 0.9) & (MSI < 1) & (y != 6))] = 10\n",
    "    # 修正不符合要求的裸地类\n",
    "    y_clean[np.where((NDVI > 0) & (NDVI < 0.15) & (y != 8))] = 10\n",
    "    # 修正裸土建筑错分到其他类\n",
    "    #y_clean[np.where((BSI > -0.4) & (NDVI < 0.15) & (y != 6) & (y != 8))] = 10\n",
    "    # 修正不符合要求的水体类\n",
    "    y_clean[np.where((NDWI > 0.) & (y != 9))] = 10\n",
    "\n",
    "    # rows\n",
    "    # forest\n",
    "\n",
    "    # 修正其他类错标到森林\n",
    "    y_clean[np.where((NDVI < 0.75) & (y == 0))] = 10\n",
    "\n",
    "    # shrubland\n",
    "    \n",
    "    # savanna\n",
    "\n",
    "    # 将热带草原标签修正为草地\n",
    "    #y_clean[np.where((NDVI > 0.4) & (NDVI < 0.55) & (y == 2) & (np.sum(y == 9) < 2000))] = 3\n",
    "    # 将热带草原标签修正为湿地\n",
    "    y_clean[np.where((NDVI > 0.6) & (NDVI < 0.75) & (y == 2) & (np.sum(y == 9) > 10000))] = 4\n",
    "    # 将热带草原标签修正为灌木\n",
    "    y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 1.5) & (y == 2))]=1\n",
    "    # 将热带草原标签修正为裸地\n",
    "    y_clean[np.where((NBI > 750) &(NDVI<0.2)& (NDVI>0) & (y == 2))] = 8\n",
    "  \n",
    "    # grassland\n",
    "\n",
    "    # 将草地标签修正为湿地\n",
    "    y_clean[np.where((NDVI > 0.6) & (NDVI < 0.75) & (y == 3) & (np.sum(y==9)>10000))] = 4\n",
    "\n",
    "    # wetland\n",
    "    # 将湿地修正为森林\n",
    "    #y_clean[np.where((NDVI > 0.75) & (y == 4))] = 0\n",
    "\n",
    "    y_clean[y_clean==2]=10\n",
    "    y_clean[y_clean==7]=10\n",
    "\n",
    "    return y_clean"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading & Preparing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 986/986 [00:00<00:00, 290440.60it/s]\n"
     ]
    }
   ],
   "source": [
    "base_dir = '/data/PublicData/DF2020/val/'\n",
    "s1_pre, s2_pre, lc_pre=load_data(base_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1_val=np.zeros([s1_pre.shape[0],256,256,2],dtype='float16')\n",
    "s2_val=np.zeros([s1_pre.shape[0],256,256,10],dtype='float16')\n",
    "y_val=np.zeros([s1_pre.shape[0],1,256,256],dtype='uint8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 986/986 [03:11<00:00,  5.15it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(s1_pre.shape[0])):\n",
    "    s1_name=s1_pre[i]\n",
    "    s2_name=s2_pre[i]\n",
    "    y_name=lc_pre[i]\n",
    "    x1,x2=read_data(s1_name,s2_name)\n",
    "    x1=x1.transpose(1,2,0)#h,w,c\n",
    "    x2=x2.transpose(1,2,0)\n",
    "    x2_tmp=x2.copy()\n",
    "    x2=x2[:,:,[1,2,3,4,5,6,7,10,11,12]]#remove b1,9,10\n",
    "    \n",
    "    x1,x2=clean(x1,x2)\n",
    "    \n",
    "    with rasterio.open(y_name) as patch:\n",
    "        y = patch.read(list(range(1, 5)))\n",
    "    y = process_label(y)\n",
    "    y = filter_label(x2_tmp, y)\n",
    "    x1, x2 = norm(x1, x2)\n",
    "    \n",
    "    s1_val[i,:,:,:]=x1\n",
    "    s2_val[i,:,:,:]=x2\n",
    "    y_val[i,:,:,:]=y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0,  1,  3,  4,  5,  6,  8,  9, 10], dtype=uint8)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "merge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_val=np.concatenate((s1_val,s2_val),axis=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(986, 256, 256, 12)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s_val.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_val=s_val.reshape(-1,12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(64618496, 12)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s_val.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_val=y_val.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(64618496,)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_val.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx=y_val!=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_val=s_val[idx]\n",
    "y_val=y_val[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32933473, 12)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s_val.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/xgboost/__init__.py:28: FutureWarning: Python 3.5 support is deprecated; XGBoost will require Python 3.6+ in the near future. Consider upgrading to Python 3.6+.\n",
      "  FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "import xgboost as xgb\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_trn,X_tes,y_trn,y_tes =train_test_split(s_val,y_val,test_size=0.4, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = xgb.XGBClassifier(\n",
    "    n_estimators=100,\n",
    "    max_depth=10,\n",
    "    tree_method='gpu_hist' , # THE MAGICAL PARAMETER\n",
    "    learning_rate=0.1,\n",
    "    n_gpus = 1 ,# -1表示使用所有GPU\n",
    "    gpu_id = 3 , # 从GPU 1 开始\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[16:39:55] WARNING: /workspace/include/xgboost/generic_parameters.h:36: \n",
      "n_gpus: \n",
      "\tDeprecated. Single process multi-GPU training is no longer supported.\n",
      "\tPlease switch to distributed training with one process per GPU.\n",
      "\tThis can be done using Dask or Spark.  See documentation for details.\n",
      "[0]\tvalidation_0-mlogloss:1.73194\n",
      "Will train until validation_0-mlogloss hasn't improved in 10 rounds.\n",
      "[1]\tvalidation_0-mlogloss:1.49609\n",
      "[2]\tvalidation_0-mlogloss:1.31663\n",
      "[3]\tvalidation_0-mlogloss:1.17257\n",
      "[4]\tvalidation_0-mlogloss:1.05277\n",
      "[5]\tvalidation_0-mlogloss:0.95156\n",
      "[6]\tvalidation_0-mlogloss:0.86431\n",
      "[7]\tvalidation_0-mlogloss:0.78860\n",
      "[8]\tvalidation_0-mlogloss:0.72225\n",
      "[9]\tvalidation_0-mlogloss:0.66376\n",
      "[10]\tvalidation_0-mlogloss:0.61206\n",
      "[11]\tvalidation_0-mlogloss:0.56595\n",
      "[12]\tvalidation_0-mlogloss:0.52481\n",
      "[13]\tvalidation_0-mlogloss:0.48811\n",
      "[14]\tvalidation_0-mlogloss:0.45522\n",
      "[15]\tvalidation_0-mlogloss:0.42566\n",
      "[16]\tvalidation_0-mlogloss:0.39902\n",
      "[17]\tvalidation_0-mlogloss:0.37506\n",
      "[18]\tvalidation_0-mlogloss:0.35336\n",
      "[19]\tvalidation_0-mlogloss:0.33377\n",
      "[20]\tvalidation_0-mlogloss:0.31616\n",
      "[21]\tvalidation_0-mlogloss:0.30020\n",
      "[22]\tvalidation_0-mlogloss:0.28572\n",
      "[23]\tvalidation_0-mlogloss:0.27268\n",
      "[24]\tvalidation_0-mlogloss:0.26082\n",
      "[25]\tvalidation_0-mlogloss:0.25006\n",
      "[26]\tvalidation_0-mlogloss:0.24027\n",
      "[27]\tvalidation_0-mlogloss:0.23135\n",
      "[28]\tvalidation_0-mlogloss:0.22329\n",
      "[29]\tvalidation_0-mlogloss:0.21594\n",
      "[30]\tvalidation_0-mlogloss:0.20921\n",
      "[31]\tvalidation_0-mlogloss:0.20307\n",
      "[32]\tvalidation_0-mlogloss:0.19744\n",
      "[33]\tvalidation_0-mlogloss:0.19239\n",
      "[34]\tvalidation_0-mlogloss:0.18773\n",
      "[35]\tvalidation_0-mlogloss:0.18351\n",
      "[36]\tvalidation_0-mlogloss:0.17966\n",
      "[37]\tvalidation_0-mlogloss:0.17611\n",
      "[38]\tvalidation_0-mlogloss:0.17287\n",
      "[39]\tvalidation_0-mlogloss:0.16989\n",
      "[40]\tvalidation_0-mlogloss:0.16718\n",
      "[41]\tvalidation_0-mlogloss:0.16466\n",
      "[42]\tvalidation_0-mlogloss:0.16239\n",
      "[43]\tvalidation_0-mlogloss:0.16027\n",
      "[44]\tvalidation_0-mlogloss:0.15834\n",
      "[45]\tvalidation_0-mlogloss:0.15651\n",
      "[46]\tvalidation_0-mlogloss:0.15484\n",
      "[47]\tvalidation_0-mlogloss:0.15330\n",
      "[48]\tvalidation_0-mlogloss:0.15187\n",
      "[49]\tvalidation_0-mlogloss:0.15057\n",
      "[50]\tvalidation_0-mlogloss:0.14929\n",
      "[51]\tvalidation_0-mlogloss:0.14816\n",
      "[52]\tvalidation_0-mlogloss:0.14708\n",
      "[53]\tvalidation_0-mlogloss:0.14609\n",
      "[54]\tvalidation_0-mlogloss:0.14518\n",
      "[55]\tvalidation_0-mlogloss:0.14436\n",
      "[56]\tvalidation_0-mlogloss:0.14358\n",
      "[57]\tvalidation_0-mlogloss:0.14286\n",
      "[58]\tvalidation_0-mlogloss:0.14219\n",
      "[59]\tvalidation_0-mlogloss:0.14150\n",
      "[60]\tvalidation_0-mlogloss:0.14085\n",
      "[61]\tvalidation_0-mlogloss:0.14027\n",
      "[62]\tvalidation_0-mlogloss:0.13971\n",
      "[63]\tvalidation_0-mlogloss:0.13915\n",
      "[64]\tvalidation_0-mlogloss:0.13860\n",
      "[65]\tvalidation_0-mlogloss:0.13812\n",
      "[66]\tvalidation_0-mlogloss:0.13764\n",
      "[67]\tvalidation_0-mlogloss:0.13719\n",
      "[68]\tvalidation_0-mlogloss:0.13677\n",
      "[69]\tvalidation_0-mlogloss:0.13634\n",
      "[70]\tvalidation_0-mlogloss:0.13596\n",
      "[71]\tvalidation_0-mlogloss:0.13554\n",
      "[72]\tvalidation_0-mlogloss:0.13514\n",
      "[73]\tvalidation_0-mlogloss:0.13478\n",
      "[74]\tvalidation_0-mlogloss:0.13443\n",
      "[75]\tvalidation_0-mlogloss:0.13409\n",
      "[76]\tvalidation_0-mlogloss:0.13375\n",
      "[77]\tvalidation_0-mlogloss:0.13343\n",
      "[78]\tvalidation_0-mlogloss:0.13316\n",
      "[79]\tvalidation_0-mlogloss:0.13281\n",
      "[80]\tvalidation_0-mlogloss:0.13255\n",
      "[81]\tvalidation_0-mlogloss:0.13233\n",
      "[82]\tvalidation_0-mlogloss:0.13211\n",
      "[83]\tvalidation_0-mlogloss:0.13188\n",
      "[84]\tvalidation_0-mlogloss:0.13165\n",
      "[85]\tvalidation_0-mlogloss:0.13138\n",
      "[86]\tvalidation_0-mlogloss:0.13116\n",
      "[87]\tvalidation_0-mlogloss:0.13096\n",
      "[88]\tvalidation_0-mlogloss:0.13077\n",
      "[89]\tvalidation_0-mlogloss:0.13054\n",
      "[90]\tvalidation_0-mlogloss:0.13036\n",
      "[91]\tvalidation_0-mlogloss:0.13015\n",
      "[92]\tvalidation_0-mlogloss:0.12993\n",
      "[93]\tvalidation_0-mlogloss:0.12979\n",
      "[94]\tvalidation_0-mlogloss:0.12958\n",
      "[95]\tvalidation_0-mlogloss:0.12941\n",
      "[96]\tvalidation_0-mlogloss:0.12924\n",
      "[97]\tvalidation_0-mlogloss:0.12908\n",
      "[98]\tvalidation_0-mlogloss:0.12891\n",
      "[99]\tvalidation_0-mlogloss:0.12874\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "XGBClassifier(base_score=0.5, booster=None, colsample_bylevel=1,\n",
       "       colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=3,\n",
       "       importance_type='gain', interaction_constraints=None,\n",
       "       learning_rate=0.1, max_delta_step=0, max_depth=10,\n",
       "       min_child_weight=1, missing=nan, monotone_constraints=None,\n",
       "       n_estimators=100, n_gpus=1, n_jobs=0, num_parallel_tree=1,\n",
       "       objective='multi:softprob', random_state=0, reg_alpha=0,\n",
       "       reg_lambda=1, scale_pos_weight=None, subsample=1,\n",
       "       tree_method='gpu_hist', validate_parameters=False, verbosity=None)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.fit(X_trn, y_trn, early_stopping_rounds=10, eval_metric=\"mlogloss\",\n",
    "        eval_set=[(X_tes, y_tes)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.953671454348501"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pre=clf.predict(X_tes)\n",
    "np.mean(y_pre==y_tes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "checking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 1 3 4 5 6 8 9]\n",
      "[0 1 3 4 5 6 8 9]\n"
     ]
    }
   ],
   "source": [
    "print(np.unique(y_trn))\n",
    "print(np.unique(y_tes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.externals import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['xgb_val_sa_shrb.pkl']"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "joblib.dump(clf,'xgb_val_sa_shrb.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predicting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "hex_color_dict={10:'000000',0:'009900',1:'c6b044',2:'fbff13',3:'b6ff05',4:'27ff87',5:'c24f44',\n",
    "                    6:'a5a5a5',7:'69fff8',8:'f9ffa4',9:'1c0dff'}\n",
    "\n",
    "def Hex_to_RGB(str):\n",
    "    r = int(str[0:2],16)\n",
    "    g = int(str[2:4],16)\n",
    "    b = int(str[4:6],16)\n",
    "    return [r,g,b]\n",
    "\n",
    "def DrawResult(labels, row, col):\n",
    "    num_class = 10\n",
    "\n",
    "    X_result = np.zeros((labels.shape[0], 3))\n",
    "    for i in range(num_class):\n",
    "        X_result[np.where(labels == i), 0] = Hex_to_RGB(hex_color_dict[i])[0]\n",
    "        X_result[np.where(labels == i), 1] = Hex_to_RGB(hex_color_dict[i])[1]\n",
    "        X_result[np.where(labels == i), 2] = Hex_to_RGB(hex_color_dict[i])[2]\n",
    "\n",
    "    X_result = np.reshape(X_result, (row, col, 3))\n",
    "\n",
    "    return X_result\n",
    "\n",
    "def Cal_INDEX(x):\n",
    "\n",
    "    x=x.astype('float')\n",
    "\n",
    "    x[x>10000]=10000\n",
    "    x[x<0]=0\n",
    "\n",
    "    R = x[:, :, 3]\n",
    "    G = x[:, :, 2]\n",
    "    Nir = x[:, :, 7]  # TM4\n",
    "    Mir = x[:, :, 10]  # TM5\n",
    "    SWir= x[:,:,11]\n",
    "\n",
    "    NDWI = (G - Nir) / (G + Nir)\n",
    "    NDVI = (Nir - R) / (Nir + R)\n",
    "    NDSI = (Mir-Nir) / (Mir+Nir)\n",
    "    NBI = R * SWir / Nir\n",
    "\n",
    "    return NDWI,NDVI, NDSI, NBI\n",
    "\n",
    "class Evaluator(object):\n",
    "    def __init__(self, num_class):\n",
    "        self.num_class = num_class\n",
    "        self.confusion_matrix = np.zeros((self.num_class,)*2)\n",
    "        # matrix shape(num_class, num_class) with elements 0 in our match. it will be 4*4\n",
    "\n",
    "    def Kappa(self):        \n",
    "        \n",
    "        xsum = np.sum(self.confusion_matrix, axis=1)  # sum by row\n",
    "        ysum = np.sum(self.confusion_matrix, axis=0)  # sum by column\n",
    "       \n",
    "        Pe = np.sum(ysum*xsum)*1.0/(self.confusion_matrix.sum()**2)\n",
    "        P0 = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()  # predict right / all the data\n",
    "        cohens_coefficient = (P0-Pe)/(1-Pe)\n",
    "\n",
    "        return cohens_coefficient\n",
    "            \n",
    "    def ProducerA(self):\n",
    "        #\n",
    "        return np.diag(self.confusion_matrix) / np.sum(self.confusion_matrix, axis=1)\n",
    "\n",
    "    def UserA(self):\n",
    "        #\n",
    "        return np.diag(self.confusion_matrix) / np.sum(self.confusion_matrix, axis=0)\n",
    "\n",
    "\n",
    "    def Pixel_Accuracy(self):\n",
    "        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()\n",
    "        return Acc\n",
    "\n",
    "    def val_Pixel_Accuracy_Class(self):\n",
    "        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)\n",
    "        # each pred right class is in diag. sum by row is the count of corresponding class\n",
    "#         index=np.array([0,1])\n",
    "#         Acc=Acc[index]\n",
    "        #Acc[-1]=90\n",
    "        Acc = np.nanmean(Acc) #\n",
    "        return Acc\n",
    "\n",
    "    def pre_Pixel_Accuracy_Class(self):\n",
    "        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)\n",
    "        # each pred right class is in diag. sum by row is the count of corresponding class\n",
    "#         index=np.array([0,1])\n",
    "#         Acc=Acc[index]\n",
    "        #Acc[-1]=90\n",
    "        Acc = np.nanmean(Acc) #\n",
    "        return Acc\n",
    "\n",
    "    def Mean_Intersection_over_Union(self):\n",
    "        MIoU = np.diag(self.confusion_matrix) / (\n",
    "                    np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -\n",
    "                    np.diag(self.confusion_matrix))\n",
    "        MIoU = np.nanmean(MIoU)\n",
    "        return MIoU\n",
    "\n",
    "    def Frequency_Weighted_Intersection_over_Union(self):\n",
    "        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)\n",
    "        iu = np.diag(self.confusion_matrix) / (\n",
    "                    np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -\n",
    "                    np.diag(self.confusion_matrix))\n",
    "\n",
    "        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()\n",
    "        return FWIoU\n",
    "\n",
    "    def _generate_matrix(self, gt_image, pre_image):\n",
    "        # gt_image = batch_size*256*256   pre_image = batch_size*256*256\n",
    "        mask = (gt_image >= 0) & (gt_image < self.num_class)  # valid in mask show True, ignored in mask show False\n",
    "        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]\n",
    "        # gt_image[mask] : find out valid pixels. elements with 0,1,2,3 , so label range in  0-15\n",
    "        count = np.bincount(label, minlength=self.num_class**2)\n",
    "        # [0, 1, 2, 3,  confusion_matrix like this:\n",
    "        #  4, 5, 6, 7,  and if the element is on the diagonal, it means predict the right class.\n",
    "        #  8, 9, 10,11, row means the real label, column means pred label\n",
    "        #  12,13,14,15]\n",
    "        # return a array [a,b....], each letters holds the count of a class and map to class0, class1...\n",
    "        confusion_matrix = count.reshape(self.num_class, self.num_class)\n",
    "        return confusion_matrix\n",
    "\n",
    "    def add_batch(self, gt_image, pre_image):\n",
    "        assert gt_image.shape == pre_image.shape\n",
    "        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)\n",
    "\n",
    "    def reset(self):\n",
    "        self.confusion_matrix = np.zeros((self.num_class,) * 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.misc import imsave\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 986/986 [00:00<00:00, 198406.44it/s]\n"
     ]
    }
   ],
   "source": [
    "base_dir = '/data/PublicData/DF2020/val/'\n",
    "s1_pre, s2_pre, lc_pre=load_data(base_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf=joblib.load('xgb_val_sa_shrb.pkl') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator=Evaluator(10)\n",
    "main_dir='/data/di.wang/ordinary/23DCNN/DataFusion_2020_new/preval/xgb_filter_sa_shrb/'\n",
    "\n",
    "outputdir=main_dir+'output/'\n",
    "visdir=main_dir+'vis/'\n",
    "\n",
    "if not os.path.exists(outputdir):\n",
    "    os.makedirs(outputdir)\n",
    "\n",
    "if not os.path.exists(visdir):\n",
    "    os.makedirs(visdir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/986 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[16:58:16] WARNING: /workspace/include/xgboost/generic_parameters.h:36: \n",
      "n_gpus: \n",
      "\tDeprecated. Single process multi-GPU training is no longer supported.\n",
      "\tPlease switch to distributed training with one process per GPU.\n",
      "\tThis can be done using Dask or Spark.  See documentation for details.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:53: DeprecationWarning: `imsave` is deprecated!\n",
      "`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
      "Use ``imageio.imwrite`` instead.\n",
      "100%|██████████| 986/986 [14:22<00:00,  1.14it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(s1_pre.shape[0])):\n",
    "    s1_name=s1_pre[i]\n",
    "    s2_name=s2_pre[i]\n",
    "    y_name=lc_pre[i]\n",
    "    x1,x2=read_data(s1_name,s2_name)\n",
    "    x1=x1.transpose(1,2,0)#h,w,c\n",
    "    x2=x2.transpose(1,2,0)\n",
    "    x2_tmp=x2.copy()\n",
    "    x2=x2[:,:,[1,2,3,4,5,6,7,10,11,12]]#remove b1,9,10\n",
    "    \n",
    "    x1,x2=clean(x1,x2)\n",
    "    \n",
    "    with rasterio.open(y_name) as patch:\n",
    "        y = patch.read(list(range(1, 5)))\n",
    "    y = process_label(y)\n",
    "    y = filter_label(x2_tmp, y)\n",
    "    x1, x2 = norm(x1, x2)\n",
    "    \n",
    "    x1=x1.reshape(-1,2)\n",
    "    x2=x2.reshape(-1,10)\n",
    "    x=np.concatenate((x1,x2),axis=1)\n",
    "    y_pre=clf.predict(x).reshape(256,256)\n",
    "    \n",
    "    # postprocessing\n",
    "    NDWI,NDVI,NDSI,NBI=Cal_INDEX(x2_tmp)\n",
    "    im = np.uint8(y_pre + 1)\n",
    "    \n",
    "    y_tmp=y[0].copy()\n",
    "    im_tmp=im.copy()\n",
    "    y_pre_tmp=y_pre.copy()\n",
    "    \n",
    "    # grassland\n",
    "    \n",
    "#     im[np.where((NDVI>0.4) & (NDVI<0.6)& (y_tmp==3) &(np.sum(im_tmp==10)<2000))]=4\n",
    "#     y_pre[np.where((NDVI>0.4) & (NDVI<0.6)&(y_tmp==3) &(np.sum(y_pre_tmp==9)<2000))]=3\n",
    "    \n",
    "#     # wetland\n",
    "    \n",
    "#     im[np.where((NDVI>0.6) & (NDVI<0.75)& ((im_tmp==4)|(im_tmp==1))&(np.sum(im_tmp==10)>2000))]=5\n",
    "#     y_pre[np.where((NDVI>0.6) & (NDVI<0.75)& ((im_tmp==4)|(im_tmp==1)) &(np.sum(y_pre_tmp==9)>2000))]=4\n",
    "    \n",
    "#     # barren\n",
    "    \n",
    "#     im[np.where((NBI > 750) &(NDVI<0.4)& (NDVI>0) & (y_tmp!=5) & (y_tmp!=6) &((im_tmp==1) | (im_tmp==4) | (im_tmp==2)))] = 9\n",
    "#     y_pre[np.where((NBI > 750) &(NDVI<0.4) & (NDVI>0) & (y_tmp!=5) & (y_tmp!=6)& ((y_pre_tmp==0) | (y_pre_tmp==3) | (y_pre_tmp==1)))] = 8\n",
    "    \n",
    "    evaluator.add_batch(y, y_pre[np.newaxis, :, :])\n",
    "    \n",
    "    filename=os.path.basename(y_name)\n",
    "    former = filename.split('lc')[0]\n",
    "    latter = filename.split('lc')[1]\n",
    "    \n",
    "    imsave(outputdir+former+'dfc'+latter,im)\n",
    "    im_rgb = Image.fromarray(np.uint8(DrawResult(y_pre.reshape(-1), 256, 256)))\n",
    "    im_rgb.save(visdir + former+'dfc'+latter[:-4] + '_vis.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AVERAGE ACCURACY of val DATASET: 0.8536393893844794\n",
      "CONFUSION MATRIX\n",
      "[[2.7735480e+06 0.0000000e+00 0.0000000e+00 0.0000000e+00 9.8600000e+02\n",
      "  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]\n",
      " [0.0000000e+00 4.3074800e+05 0.0000000e+00 1.4594200e+05 4.4000000e+01\n",
      "  4.7560000e+04 8.5040000e+03 0.0000000e+00 4.9410000e+03 2.0000000e+01]\n",
      " [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      "  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]\n",
      " [2.8000000e+01 8.1360000e+04 0.0000000e+00 1.5632350e+06 2.0100000e+03\n",
      "  3.0959300e+05 1.4074400e+05 0.0000000e+00 1.9313000e+04 1.6832000e+04]\n",
      " [1.0340000e+03 2.2900000e+02 0.0000000e+00 1.1291000e+04 1.9335440e+06\n",
      "  2.4760000e+03 1.0240000e+04 0.0000000e+00 1.4320000e+03 7.3650000e+03]\n",
      " [1.4000000e+01 1.0249000e+04 0.0000000e+00 1.3655100e+05 2.8600000e+02\n",
      "  3.1404870e+06 1.2036800e+05 0.0000000e+00 5.0240000e+03 4.0300000e+02]\n",
      " [3.0000000e+01 3.9880000e+03 0.0000000e+00 9.5826000e+04 7.6100000e+02\n",
      "  1.3976700e+05 1.3666470e+06 0.0000000e+00 1.5426000e+04 3.2400000e+03]\n",
      " [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      "  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]\n",
      " [0.0000000e+00 3.7030000e+03 0.0000000e+00 1.8776000e+04 7.0000000e+01\n",
      "  2.6183000e+04 6.1810000e+04 0.0000000e+00 2.3973800e+05 2.1320000e+03]\n",
      " [1.2000000e+01 4.9000000e+01 0.0000000e+00 1.9101000e+04 1.0850000e+03\n",
      "  1.8460000e+03 2.0871000e+04 0.0000000e+00 6.1870000e+03 1.9979824e+07]]\n",
      "ACCURACY IN EACH CLASSES: [0.99964463 0.67540874        nan 0.73284141 0.98268611 0.92005143\n",
      " 0.84065917        nan 0.68027763 0.99754601]\n",
      "Prediction finished!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:83: RuntimeWarning: invalid value encountered in true_divide\n",
      "/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:9: RuntimeWarning: invalid value encountered in true_divide\n",
      "  if __name__ == '__main__':\n"
     ]
    }
   ],
   "source": [
    "AA = evaluator.pre_Pixel_Accuracy_Class()\n",
    "\n",
    "print('AVERAGE ACCURACY of val DATASET: {}'.format(AA))\n",
    "\n",
    "print('CONFUSION MATRIX')\n",
    "\n",
    "print(evaluator.confusion_matrix)\n",
    "\n",
    "print('ACCURACY IN EACH CLASSES:',np.diag(evaluator.confusion_matrix) / evaluator.confusion_matrix.sum(axis=1))\n",
    "\n",
    "print('Prediction finished!')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
